#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Core/Settings.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context.h>
#include <Common/AlignedBuffer.h>
#include <Common/Arena.h>
#include <Common/scope_guard_safe.h>


namespace DB
{
namespace Setting
{
    extern const SettingsBool allow_deprecated_error_prone_window_functions;
}

namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int DEPRECATED_FUNCTION;
}

namespace
{

/** runningAccumulate(agg_state[, grouping]) - takes the states of an aggregate function and returns a column of values,
  * where each value is the result of merging the states from the first row up to and including the current row.
  *
  * Quite unusual function.
  * Takes state of aggregate function (example runningAccumulate(uniqState(UserID))),
  *  and for each row of the block returns the result of the aggregate function on the merge of the states
  *  of all previous rows and the current row.
  * If the optional second argument `grouping` is given, the accumulated state is reset every time its value
  *  differs from the value in the previous row.
  *
  * So, the result of the function depends on how the data is split into blocks and on the order of rows in each block.
  */
class FunctionRunningAccumulate : public IFunction
{
public:
    static constexpr auto name = "runningAccumulate";

    /// Creates the function. Throws DEPRECATED_FUNCTION unless the
    /// `allow_deprecated_error_prone_window_functions` setting is enabled.
    static FunctionPtr create(ContextPtr context)
    {
        if (!context->getSettingsRef()[Setting::allow_deprecated_error_prone_window_functions])
            throw Exception(
                ErrorCodes::DEPRECATED_FUNCTION,
                "Function {} is deprecated since its usage is error-prone (see docs). "
                "Please use proper window function or set `allow_deprecated_error_prone_window_functions` setting to enable it",
                name);

        return std::make_shared<FunctionRunningAccumulate>();
    }

    String getName() const override
    {
        return name;
    }

    /// The result for a row depends on the rows that precede it in the same block.
    bool isStateful() const override
    {
        return true;
    }

    /// Accepts 1 or 2 arguments; the exact count is validated in getReturnTypeImpl.
    bool isVariadic() const override { return true; }

    size_t getNumberOfArguments() const override { return 0; }

    bool isDeterministic() const override
    {
        return false;
    }

    bool isDeterministicInScopeOfQuery() const override
    {
        return false;
    }

    /// The accumulated value of every row depends on all preceding rows of the block.
    /// Lazy (short-circuit) execution could evaluate the function only on the rows selected by an
    /// enclosing condition (e.g. inside `if`), which would silently change the running result,
    /// so the function must always be computed over the full block.
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }

    /// Returns the final result type of the aggregate function whose state is passed as the first argument.
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless there are 1 or 2 arguments, and
    /// ILLEGAL_TYPE_OF_ARGUMENT if the first argument is not of type AggregateFunction.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        if (arguments.empty() || arguments.size() > 2)
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Incorrect number of arguments of function {}. Must be 1 or 2.", getName());

        const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
        if (!type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                            "Argument for function {} must have type AggregateFunction - state "
                            "of aggregate function.", getName());

        return type->getReturnType();
    }

    /// For each row, merges its state into a single running state and emits the finalized result.
    /// The running state is recreated at the first row and whenever the optional grouping column
    /// differs from the previous row. Throws ILLEGAL_COLUMN if the first argument is not a ColumnAggregateFunction.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
    {
        const ColumnAggregateFunction * column_with_states
            = typeid_cast<const ColumnAggregateFunction *>(arguments.at(0).column.get());

        if (!column_with_states)
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}",
                    arguments.at(0).column->getName(), getName());

        /// Optional grouping key: a change of its value between adjacent rows restarts the accumulation.
        ColumnPtr column_with_groups;

        if (arguments.size() == 2)
            column_with_groups = arguments[1].column;

        /// Keep the owning pointer alive so the reference below stays valid.
        AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
        const IAggregateFunction & agg_func = *aggregate_function_ptr;

        /// Storage for the single running state, sized and aligned for this aggregate function.
        AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());

        /// Will pass empty arena if agg_func does not allocate memory in arena
        std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;

        auto result_column_ptr = agg_func.getResultType()->createColumn();
        IColumn & result_column = *result_column_ptr;
        result_column.reserve(column_with_states->size());

        const auto & states = column_with_states->getData();

        /// Destroy the running state on every exit path, including when merge/insertResultInto throws.
        bool state_created = false;
        SCOPE_EXIT_MEMORY_SAFE({
            if (state_created)
                agg_func.destroy(place.data());
        });

        size_t row_number = 0;
        for (const auto & state_to_add : states)
        {
            /// Start a fresh running state at the first row and whenever the grouping key changes.
            if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0))
            {
                if (state_created)
                {
                    agg_func.destroy(place.data());
                    state_created = false;
                }

                agg_func.create(place.data()); /// This function can throw.
                state_created = true;
            }

            agg_func.merge(place.data(), state_to_add, arena.get());
            agg_func.insertResultInto(place.data(), result_column, arena.get());

            ++row_number;
        }

        return result_column_ptr;
    }
};

}

REGISTER_FUNCTION(RunningAccumulate)
{
    /// Assemble the user-facing documentation piece by piece, then register the function with it.
    FunctionDocumentation::Description description{R"(
Accumulates the states of an aggregate function for each row of a data block.

:::warning Deprecated
The state is reset for each new block of data.
Due to this error-prone behavior the function has been deprecated, and you are advised to use [window functions](/sql-reference/window-functions) instead.
You can use setting [`allow_deprecated_error_prone_window_functions`](/operations/settings/settings#allow_deprecated_error_prone_window_functions) to allow usage of this function.
:::
)"};

    FunctionDocumentation::Syntax syntax{"runningAccumulate(agg_state[, grouping])"};

    FunctionDocumentation::Arguments args_doc = {
        {"agg_state", "State of the aggregate function.", {"AggregateFunction"}},
        {"grouping", "Optional. Grouping key. The state of the function is reset if the `grouping` value is changed. It can be any of the supported data types for which the equality operator is defined.", {"Any"}},
    };

    FunctionDocumentation::ReturnedValue returns_doc = {"Returns the accumulated result for each row.", {"Any"}};

    FunctionDocumentation::Examples usage_examples = {
        {"Usage example with initializeAggregation",
         R"(
WITH initializeAggregation('sumState', number) AS one_row_sum_state
SELECT
    number,
    finalizeAggregation(one_row_sum_state) AS one_row_sum,
    runningAccumulate(one_row_sum_state) AS cumulative_sum
FROM numbers(5);
        )",
         R"(
┌─number─┬─one_row_sum─┬─cumulative_sum─┐
│      0 │           0 │              0 │
│      1 │           1 │              1 │
│      2 │           2 │              3 │
│      3 │           3 │              6 │
│      4 │           4 │             10 │
└────────┴─────────────┴────────────────┘
        )"},
    };

    FunctionDocumentation::IntroducedIn since_version = {1, 1};
    FunctionDocumentation::Category doc_category = FunctionDocumentation::Category::Other;

    FunctionDocumentation documentation{description, syntax, args_doc, returns_doc, usage_examples, since_version, doc_category};

    factory.registerFunction<FunctionRunningAccumulate>(documentation);
}

}
